/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

/*!
 * \file kernel_common.h
 * \brief
 */
#ifndef ASCENDC_KERNEL_COMMON_H
#define ASCENDC_KERNEL_COMMON_H

#include "kernel_reg.h"
#include "kernel_process_lock.h"
#include "kernel_operator_tensor_trait.h"
#include "kernel_operator_cache_intf.h"
#include "kernel_operator_block_sync_intf.h"
#include "kernel_operator_sys_var_intf.h"
#include "kernel_operator_swap_mem_intf.h"

// Forward declarations for types that this header only names.
// TPipe is referenced here purely through pointers (the per-core pipe pointers and GetTPipePtr),
// so its full definition is not required. KfcCommClient is not referenced in the visible part of
// this header; presumably it is declared here for code that includes this file — confirm.
namespace AscendC {
class KfcCommClient;
class TPipe;
} // namespace AscendC

// Per-core TPipe pointer, read by GetTPipePtr() below (non-UT/ST builds).
// On 2201/3101/5102, split-core builds declare only a dedicated pointer for that core type
// (g_cubeTPipePtr for SPLIT_CORE_CUBE, g_vecTPipePtr for SPLIT_CORE_VEC); every other
// configuration declares g_tPipePtr.
#if __NPU_ARCH__ == 2201 || (__NPU_ARCH__ == 3101) || (__NPU_ARCH__ == 5102)
// Super-kernel configuration words; each is declared only when its build macro is defined.
// Their readers/writers are not visible in this header.
#if defined(__ASCENDC_SUPERKERNEL_EARLY_START_V1) || defined(__ASCENDC_SUPERKERNEL_EARLY_START_V2)
__BLOCK_LOCAL__ __inline__ uint32_t g_super_kernel_early_start_config;
#endif
#ifdef __SUPER_KERNEL_DYNAMIC_BLOCK_NUM__
__BLOCK_LOCAL__ __inline__ uint32_t g_super_kernel_dynamic_block_num;
#endif
#ifdef SPLIT_CORE_CUBE
__BLOCK_LOCAL__ __inline__ AscendC::TPipe* g_cubeTPipePtr;
#elif defined(SPLIT_CORE_VEC)
__BLOCK_LOCAL__ __inline__ AscendC::TPipe* g_vecTPipePtr;
#else
__BLOCK_LOCAL__ __inline__ AscendC::TPipe* g_tPipePtr;
#endif
#else
// All other architectures: a single, non-split TPipe pointer.
__BLOCK_LOCAL__ __inline__ AscendC::TPipe* g_tPipePtr;
#endif

// g_maskCount caches a mask value: SetVectorMask()/ResetMask() below write it on 3002/3102.
// It is also declared for 3101/5102; uses on those targets are not visible in this header.
// g_deqValue is not referenced in this header (presumably a dequantization value, judging by
// the name — confirm).
#if __NPU_ARCH__ == 3002 || __NPU_ARCH__ == 3102 || (__NPU_ARCH__ == 3101) || (__NPU_ARCH__ == 5102)
__BLOCK_LOCAL__ __inline__ uint64_t g_maskCount;
__BLOCK_LOCAL__ __inline__ half g_deqValue;
#endif
#if __NPU_ARCH__ == 2201
__BLOCK_LOCAL__ __inline__ half g_deqValue;
#endif

// Backing storage for the deprecated Set/GetDumpWorkSpacePtr helpers at the end of this file.
__BLOCK_LOCAL__ __inline__ __gm__ uint8_t* g_dumpWorkspaceReserved;
// Two HCCL context slots, accessed via SetHcclContext<index>() / GetHcclContext<index>() (index 0 or 1).
__BLOCK_LOCAL__ __inline__ __gm__ uint8_t* g_hcclContextReserved[2];

#if defined(UT_TEST) || defined(ST_TEST)
// UT/ST builds only declare the accessor here; its definition is provided elsewhere.
__aicore__ AscendC::TPipe* GetTPipePtr();
#else
/*!
 * \brief Return the TPipe pointer registered for the current core.
 * \return g_cubeTPipePtr / g_vecTPipePtr for SPLIT_CORE_CUBE / SPLIT_CORE_VEC builds on the
 *         architectures that declare split pointers, otherwise g_tPipePtr.
 *
 * The architecture list below must match the one guarding the TPipe pointer declarations at the
 * top of this file: on those architectures a split-core build declares ONLY the split pointer,
 * so falling back to g_tPipePtr there would reference an undeclared variable (5102 was missing).
 */
__aicore__ inline AscendC::TPipe* GetTPipePtr()
{
#if (__NPU_ARCH__ == 2201) || (__NPU_ARCH__ == 3101) || (__NPU_ARCH__ == 5102)
#ifdef SPLIT_CORE_CUBE
    return g_cubeTPipePtr;
#elif defined(SPLIT_CORE_VEC)
    return g_vecTPipePtr;
#else
    return g_tPipePtr;
#endif
#else
    return g_tPipePtr;
#endif
}
#endif

namespace AscendC {
/*!
 * \brief Set the vector mask from two 64-bit words.
 * \tparam T    Element type, forwarded to SetVectorMaskImpl.
 * \tparam mode Mask mode (MaskMode::NORMAL by default).
 * \param maskHigh Upper mask word, forwarded unchanged to SetVectorMaskImpl.
 * \param maskLow  Lower mask word; on 3002/3102 in COUNTER mode it is also cached in g_maskCount.
 */
template <typename T, MaskMode mode = MaskMode::NORMAL>
__aicore__ static inline void SetVectorMask(const uint64_t maskHigh, const uint64_t maskLow)
{
#if __NPU_ARCH__ == 3002 || __NPU_ARCH__ == 3102
    // mode is a template parameter, so the cache update is selected at compile time.
    if constexpr (mode == MaskMode::COUNTER) {
        g_maskCount = maskLow;
    }
#endif
    SetVectorMaskImpl<T, mode>(maskHigh, maskLow);
}

/*!
 * \brief Set the vector mask from a length.
 * \tparam T    Element type, forwarded to SetVectorMaskImpl.
 * \tparam mode Mask mode (MaskMode::NORMAL by default).
 * \param len Mask length forwarded to SetVectorMaskImpl; on 3002/3102 it is cached in g_maskCount
 *            regardless of mode.
 */
template <typename T, MaskMode mode = MaskMode::NORMAL>
__aicore__ static inline void SetVectorMask(int32_t len)
{
#if __NPU_ARCH__ == 3002 || __NPU_ARCH__ == 3102
    // Explicit conversion into the 64-bit unsigned cache (same as the implicit one).
    // NOTE(review): a negative len wraps to a huge value here — callers presumably never pass one.
    g_maskCount = static_cast<uint64_t>(len);
#endif
    SetVectorMaskImpl<T, mode>(len);
}

/*!
 * \brief Reset the vector mask via ResetMaskImpl; on 3002/3102 also clears the cached g_maskCount.
 */
__aicore__ inline void ResetMask()
{
#if __NPU_ARCH__ == 3002 || __NPU_ARCH__ == 3102
    g_maskCount = 0;
#endif
    ResetMaskImpl();
}

#if defined(__NPU_ARCH__) && ((__NPU_ARCH__ == 3101) || (__NPU_ARCH__ == 5102))
using MutexID = uint8_t;

/*!
 * \brief Pipe-scoped lock/unlock on a MutexID (3101/5102 only).
 *
 * Valid ids are 0..MAX_MUTEXID inclusive, the same bound AllocMutexID/ReleaseMutexID check.
 */
class Mutex {
public:
    /*!
     * \brief Lock \p id for \p pipe by calling GetBufInternal<pipe, 0>(id)
     *        (presumably waits until the id is free on that pipe — confirm against GetBufInternal).
     */
    template <pipe_t pipe>
    static __aicore__ inline void Lock(MutexID id)
    {
        ASCENDC_ASSERT((id <= MAX_MUTEXID), {
            KERNEL_LOG(KERNEL_ERROR, "For Mutex::Lock current id is %u, max MutexID is %u",
                       static_cast<uint32_t>(id), static_cast<uint32_t>(MAX_MUTEXID));
        });
        GetBufInternal<pipe, 0>(id);
    }

    /*!
     * \brief Unlock \p id for \p pipe by calling RlsBufInternal<pipe, 0>(id).
     */
    template <pipe_t pipe>
    static __aicore__ inline void Unlock(MutexID id)
    {
        ASCENDC_ASSERT((id <= MAX_MUTEXID), {
            KERNEL_LOG(KERNEL_ERROR, "For Mutex::Unlock current id is %u, max MutexID is %u",
                       static_cast<uint32_t>(id), static_cast<uint32_t>(MAX_MUTEXID));
        });
        RlsBufInternal<pipe, 0>(id);
    }
};

/*!
 * \brief Allocate a MutexID: take the index sff0 reports for Internal::g_bufId and set that bit.
 * \return The allocated id (asserted to be <= MAX_MUTEXID).
 */
__aicore__ inline MutexID AllocMutexID()
{
    MutexID id = static_cast<uint8_t>(sff0(Internal::g_bufId));
    // Check the id before marking it, so a failing assertion is reported before g_bufId is modified.
    ASCENDC_ASSERT((id <= MAX_MUTEXID), {
        KERNEL_LOG(KERNEL_ERROR, "current id is %u, max buffer ID allocated is %u", static_cast<uint32_t>(id),
                   static_cast<uint32_t>(MAX_MUTEXID));
    });
    Internal::g_bufId = sbitset1(Internal::g_bufId, id);
    return id;
}

/*!
 * \brief Release a MutexID previously returned by AllocMutexID by clearing its bit in Internal::g_bufId.
 * \param id Id to release; must be <= MAX_MUTEXID. This is the same inclusive bound used by
 *           AllocMutexID and Mutex::Lock/Unlock. The previous `<` check rejected the largest id
 *           that AllocMutexID can hand out.
 */
__aicore__ inline void ReleaseMutexID(MutexID id)
{
    ASCENDC_ASSERT((id <= MAX_MUTEXID), {
        KERNEL_LOG(KERNEL_ERROR,
            "current id is %d, which should be larger than or equals to 0, and smaller than or equals to %d",
            static_cast<int32_t>(id), static_cast<int32_t>(MAX_MUTEXID));
    });
    Internal::g_bufId = sbitset0(Internal::g_bufId, id);
}

#endif

/*! \brief Switch mask handling to counter mode (forwards to SetMaskCountImpl). */
__aicore__ inline void SetMaskCount()
{
    SetMaskCountImpl();
}

/*! \brief Switch mask handling to normal mode (forwards to SetMaskNormImpl). */
__aicore__ inline void SetMaskNorm()
{
    SetMaskNormImpl();
}

/*! \brief Enable/disable HF32 mode (forwards to SetHF32ModeImpl). */
__aicore__ inline void SetHF32Mode(bool hf32Mode)
{
    SetHF32ModeImpl(hf32Mode);
}

/*! \brief Set the HF32 trans mode flag (forwards to SetHF32TransModeImpl). */
__aicore__ inline void SetHF32TransMode(bool hf32TransMode)
{
    SetHF32TransModeImpl(hf32TransMode);
}

/*! \brief Set the MM layout-transform flag (forwards to SetMMLayoutTransformImpl). */
__aicore__ inline void SetMMLayoutTransform(bool mmLayoutMode)
{
    SetMMLayoutTransformImpl(mmLayoutMode);
}

/*!
 * \brief Store an HCCL context pointer in slot \p index of g_hcclContextReserved.
 * \tparam index Slot number; only 0 and 1 exist, larger values are silently ignored.
 *
 * The store sits inside the guarded branch so g_hcclContextReserved[index] is never instantiated
 * with an out-of-range index. With the old early-return form the trailing statement was still
 * compiled for index > 1.
 */
template <uint32_t index>
__aicore__ inline void SetHcclContext(__gm__ uint8_t* context)
{
    if constexpr (index <= 1) {
        g_hcclContextReserved[index] = context;
    }
}

/*!
 * \brief Read the HCCL context pointer from slot \p index of g_hcclContextReserved.
 * \tparam index Slot number; only 0 and 1 exist.
 * \return The stored pointer, or nullptr for index > 1.
 */
template <uint32_t index>
__aicore__ inline __gm__ uint8_t* __gm__ GetHcclContext(void)
{
    if constexpr (index <= 1) {
        return g_hcclContextReserved[index];
    } else {
        return nullptr;
    }
}

#if defined(__NPU_ARCH__) &&                                                            \
    ((__NPU_ARCH__ == 2201) || (__NPU_ARCH__ == 2002) || (__NPU_ARCH__ == 3002) ||      \
     (__NPU_ARCH__ == 3101) || (__NPU_ARCH__ == 5102))
/*!
 * \brief Configure AIPP from a single global-memory source (forwards to SetAippFunctionsImpl).
 * \param src0   Source tensor; its physical GM address is passed on.
 * \param format AIPP input format.
 * \param config AIPP parameters.
 */
template <typename T, typename U>
__aicore__ inline void SetAippFunctions(const GlobalTensor<T>& src0, AippInputFormat format, AippParams<U> config)
{
    SetAippFunctionsImpl<PrimT<T>, U>(const_cast<__gm__ PrimT<T>*>(src0.GetPhyAddr()), format, config);
}

/*!
 * \brief Configure AIPP from two global-memory sources (forwards to SetAippFunctionsImpl).
 * \param src0   First source tensor; its physical GM address is passed on.
 * \param src1   Second source tensor; its physical GM address is passed on.
 * \param format AIPP input format.
 * \param config AIPP parameters.
 */
template <typename T, typename U>
__aicore__ inline void SetAippFunctions(const GlobalTensor<T>& src0, const GlobalTensor<T>& src1,
                                        AippInputFormat format, AippParams<U> config)
{
    SetAippFunctionsImpl<PrimT<T>, U>(const_cast<__gm__ PrimT<T>*>(src0.GetPhyAddr()),
                                      const_cast<__gm__ PrimT<T>*>(src1.GetPhyAddr()), format, config);
}
#endif // (__NPU_ARCH__ == 2201) || (__NPU_ARCH__ == 2002) || (__NPU_ARCH__ == 3002) || (__NPU_ARCH__ == 3101) || (__NPU_ARCH__ == 5102)
} // namespace AscendC

// Deprecated: records `workspace` in g_dumpWorkspaceReserved and hands back the stored pointer.
[[deprecated("NOTICE: SetDumpWorkSpacePtr has been deprecated and will be removed in the next version. "
             "Please do not use it!")]]
__aicore__ inline __gm__ uint8_t* __gm__ SetDumpWorkSpacePtr(__gm__ uint8_t* workspace)
{
    g_dumpWorkspaceReserved = workspace;
    return g_dumpWorkspaceReserved;
}
// Deprecated: returns the current value of g_dumpWorkspaceReserved (written by SetDumpWorkSpacePtr).
[[deprecated("NOTICE: GetDumpWorkSpacePtr has been deprecated and will be removed in the next version. "
             "Please do not use it!")]]
__aicore__ inline __gm__ uint8_t* __gm__ GetDumpWorkSpacePtr()
{
    __gm__ uint8_t* const reserved = g_dumpWorkspaceReserved;
    return reserved;
}
#if !defined(ASCENDC_CPU_DEBUG)
// Device build (deprecated): store the workspace pointer in g_sysWorkspaceReserved.
// g_sysWorkspaceReserved is not declared in the visible part of this header; presumably it comes
// from one of the included headers — confirm.
[[deprecated(
    "NOTICE: SetSysWorkSpacePtr has been deprecated and will be removed in the next version.")]]
__aicore__ inline void SetSysWorkSpacePtr(__gm__ uint8_t* workspace)
{
    g_sysWorkspaceReserved = workspace;
}
#else
// CPU-debug build: declaration only; the definition is provided elsewhere.
// NOTE(review): unlike the device version this declaration has no [[deprecated]] attribute —
// possibly intentional if the CPU-debug framework calls it internally; confirm.
__aicore__ void SetSysWorkSpacePtr(__gm__ uint8_t* workspace);
#endif
#endif
